{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Einops tutorial, part 2: deep learning\n",
    "\n",
    "Previous part of tutorial provides visual examples with numpy.\n",
    "\n",
    "## What's in this tutorial?\n",
    "\n",
    "- working with deep learning packages\n",
    "- important cases for deep learning models\n",
    "- `einsops.asnumpy` and `einops.layers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops import rearrange, reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "x = np.random.RandomState(42).normal(size=[10, 32, 100, 200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# utility to hide answers\n",
    "from utils import guess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select your flavour\n",
    "\n",
    "Switch to the framework you're most comfortable with. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select one from 'chainer', 'gluon', 'tensorflow', 'pytorch' \n",
    "flavour = 'pytorch'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "selected pytorch backend\n"
     ]
    }
   ],
   "source": [
    "print('selected {} backend'.format(flavour))\n",
    "if flavour == 'tensorflow':\n",
    "    import tensorflow as tf\n",
    "    tape = tf.GradientTape(persistent=True)\n",
    "    tape.__enter__()\n",
    "    x = tf.Variable(x) + 0\n",
    "elif flavour == 'pytorch':\n",
    "    import torch\n",
    "    x = torch.from_numpy(x)\n",
    "    x.requires_grad = True\n",
    "elif flavour == 'chainer':\n",
    "    import chainer\n",
    "    x = chainer.Variable(x)\n",
    "else:\n",
    "    assert flavour == 'gluon'\n",
    "    import mxnet as mx\n",
    "    mx.autograd.set_recording(True)\n",
    "    x = mx.nd.array(x, dtype=x.dtype)\n",
    "    x.attach_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Tensor, torch.Size([10, 32, 100, 200]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(x), x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple computations \n",
    "\n",
    "- converting bchw to bhwc format and back is a common operation in CV\n",
    "- try to predict output shape and then check your guess!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b c h w -> b h w c')\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Worked!\n",
    "\n",
    "Did you notice? Code above worked for you backend of choice. <br />\n",
    "Einops functions work with any tensor like they are native to the framework.\n",
    "\n",
    "## Backpropagation\n",
    "\n",
    "- gradients are a corner stone of deep learning\n",
    "- You can back-propagate through einops operations <br />\n",
    "  (just as with framework native operations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(320., dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "y0 = x\n",
    "y1 = reduce(y0, 'b c h w -> b c', 'max')\n",
    "y2 = rearrange(y1, 'b c -> c b')\n",
    "y3 = reduce(y2, 'c b -> ', 'sum')\n",
    "\n",
    "if flavour == 'tensorflow':\n",
    "    print(reduce(tape.gradient(y3, x), 'b c h w -> ', 'sum'))\n",
    "else:\n",
    "    y3.backward()\n",
    "    print(reduce(x.grad, 'b c h w -> ', 'sum'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Meet `einops.asnumpy`\n",
    "\n",
    "Just converts tensors to numpy (and pulls from gpu if necessary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "from einops import asnumpy\n",
    "y3_numpy = asnumpy(y3)\n",
    "\n",
    "print(type(y3_numpy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Common building blocks of deep learning\n",
    "\n",
    "Let's check how some familiar operations can be written with `einops`\n",
    "\n",
    "**Flattening** is common operation, frequently appears at the boundary\n",
    "between convolutional layers and fully connected layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 640000)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b c h w -> b (c h w)')\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**space-to-depth**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 128, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b c (h h1) (w w1) -> b (h1 w1 c) h w', h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**depth-to-space** (notice that it's reverse of the previous)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 8, 200, 400)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b (c h1 w1) h w -> b c (h h1) (w w1)', h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reductions\n",
    "\n",
    "Simple **global average pooling**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = reduce(x, 'b c h w -> b c', reduction='mean')\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "**max-pooling** with a kernel 2x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = reduce(x, 'b c (h h1) (w w1) -> b c h w', reduction='max', h1=2, w1=2)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 50, 100)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# you can skip names for reduced axes\n",
    "y = reduce(x, 'b c (h 2) (w 2) -> b c h w', reduction='max')\n",
    "guess(y.shape)"
   ]
  },
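  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pooling patterns above can be cross-checked against plain numpy. The snippet below is a minimal sketch, assuming a small made-up array `demo` (a hypothetical name, not part of the tutorial's data): 2x2 max-pooling written with `reduce` should match an explicit reshape followed by a maximum over the two window axes.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import reduce\n",
    "\n",
    "# hypothetical small input: batch=2, channels=3, height=4, width=6\n",
    "demo = np.random.RandomState(0).normal(size=[2, 3, 4, 6])\n",
    "\n",
    "# 2x2 max-pooling written as an einops reduction\n",
    "pooled = reduce(demo, 'b c (h 2) (w 2) -> b c h w', 'max')\n",
    "\n",
    "# the same computation with an explicit reshape: axes are (b, c, h, h1, w, w1)\n",
    "manual = demo.reshape(2, 3, 2, 2, 3, 2).max(axis=(3, 5))\n",
    "\n",
    "assert pooled.shape == (2, 3, 2, 3)\n",
    "assert np.allclose(pooled, manual)\n",
    "```\n",
    "\n",
    "The explicit version has to spell out every axis length, while the pattern only names the window sizes."
   ]
  },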
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1d, 2d and 3d pooling are defined in a similar way\n",
    "\n",
    "for sequential 1-d models, you'll probably want pooling over time\n",
    "```python\n",
    "reduce(x, '(t 2) b c -> t b c', reduction='max')\n",
    "```\n",
    "\n",
    "for volumetric models, all three dimensions are pooled\n",
    "```python\n",
    "reduce(x, 'b c (x 2) (y 2) (z 2) -> b c x y z', reduction='max')\n",
    "```\n",
    "\n",
    "Uniformity is a strong point of `einops`, and you don't need specific operation for each particular case.\n",
    "\n",
    "\n",
    "### Good exercises \n",
    "\n",
    "- write a version of space-to-depth for 1d and 3d (2d is provided above)\n",
    "- write an average / max pooling for 1d models. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Squeeze and unsqueeze (expand_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# models typically work only with batches, \n",
    "# so to predict a single image ...\n",
    "image = rearrange(x[0, :3], 'c h w -> h w c')\n",
    "# ... create a dummy 1-element axis ...\n",
    "y = rearrange(image, 'h w c -> () c h w')\n",
    "# ... imagine you predicted this with a convolutional network for classification,\n",
    "# we'll just flatten axes ...\n",
    "predictions = rearrange(y, 'b c h w -> b (c h w)')\n",
    "# ... finally, decompose (remove) dummy axis\n",
    "predictions = rearrange(predictions, '() classes -> classes')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## keepdims-like behavior for reductions\n",
    "\n",
    "- empty composition `()` provides dimensions of length 1, which are broadcastable.\n",
    "- alternatively, you can use just `1` to introduce new axis, that's a synonym to `()`\n",
    "\n",
    "per-channel mean-normalization for *each image*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = x - reduce(x, 'b c h w -> b c 1 1', 'mean')\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "per-channel mean-normalization for *whole batch*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = x - reduce(y, 'b c h w -> 1 c 1 1', 'mean')\n",
    "guess(y.shape)"
   ]
  },
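  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both normalizations above rely on reductions that keep unit axes so the result broadcasts against the input. A minimal sketch of the correspondence with numpy's `keepdims=True`, assuming a small made-up array `demo` (hypothetical, not the tutorial's `x`):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import reduce\n",
    "\n",
    "# hypothetical small input: batch=2, channels=3, height=4, width=5\n",
    "demo = np.random.RandomState(0).normal(size=[2, 3, 4, 5])\n",
    "\n",
    "# mean per image and channel, unit axes kept for broadcasting\n",
    "per_image = reduce(demo, 'b c h w -> b c 1 1', 'mean')\n",
    "# mean per channel over the whole batch\n",
    "per_batch = reduce(demo, 'b c h w -> 1 c 1 1', 'mean')\n",
    "\n",
    "assert per_image.shape == (2, 3, 1, 1)\n",
    "assert np.allclose(per_image, demo.mean(axis=(2, 3), keepdims=True))\n",
    "assert per_batch.shape == (1, 3, 1, 1)\n",
    "assert np.allclose(per_batch, demo.mean(axis=(0, 2, 3), keepdims=True))\n",
    "```\n",
    "\n",
    "The pattern states which axes survive, so there is no need to count axis indices by hand."
   ]
  },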
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stacking\n",
    "\n",
    "let's take a list of tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "list_of_tensors = list(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "New axis (one that enumerates tensors) appears *first* on the left side of expression.\n",
    "Just as if you were indexing list - first you'd get tensor by index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, 'b c h w -> b h w c')\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(100, 200, 32, 10)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# or maybe stack along last dimension?\n",
    "tensors = rearrange(list_of_tensors, 'b c h w -> h w c b')\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concatenation\n",
    "\n",
    "concatenate over the first dimension?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(1000, 200, 32)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, 'b c h w -> (b h) w c')\n",
    "guess(tensors.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "or maybe concatenate along last dimension?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(100, 200, 320)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tensors = rearrange(list_of_tensors, 'b c h w -> h w (b c)')\n",
    "guess(tensors.shape)"
   ]
  },
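  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The stacking and concatenation patterns above map onto familiar numpy calls. A minimal sketch, assuming a made-up list `arrays` of four small 2-d arrays (hypothetical data, smaller than `list_of_tensors`):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical list: four arrays of shape (2, 3), filled with 0, 1, 2, 3\n",
    "arrays = [np.full((2, 3), i) for i in range(4)]\n",
    "\n",
    "# the list index becomes the first axis: same as np.stack\n",
    "stacked = rearrange(arrays, 'n h w -> n h w')\n",
    "assert np.array_equal(stacked, np.stack(arrays))\n",
    "\n",
    "# merging the list axis into the first dimension: concatenation along axis 0\n",
    "concat_first = rearrange(arrays, 'n h w -> (n h) w')\n",
    "assert np.array_equal(concat_first, np.concatenate(arrays, axis=0))\n",
    "\n",
    "# merging the list axis into the last dimension: concatenation along axis 1\n",
    "concat_last = rearrange(arrays, 'n h w -> h (n w)')\n",
    "assert np.array_equal(concat_last, np.concatenate(arrays, axis=1))\n",
    "```\n",
    "\n",
    "Note that the list axis comes first inside the parentheses; writing `(h n)` or `(w n)` would interleave the arrays instead of concatenating them."
   ]
  },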
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shuffling within a dimension\n",
    "\n",
    "**channel shuffle** (as it is drawn in shufflenet paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b (g1 g2 c) h w-> b (g2 g1 c) h w', g1=4, g2=4)\n",
    "guess(y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "simpler version of **channel shuffle**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 32, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "y = rearrange(x, 'b (g c) h w-> b (c g) h w', g=4)\n",
    "guess(y.shape)"
   ]
  },
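  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the simpler channel shuffle does to channel order, it helps to apply the same pattern to channel labels only. A minimal sketch, assuming a made-up 1-d array `channels` with 8 labelled channels (hypothetical, the tutorial's `x` has 32):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical input: 8 channels labelled 0..7, viewed as 4 groups of 2\n",
    "channels = np.arange(8)\n",
    "\n",
    "# same pattern as above, without the batch and spatial axes\n",
    "shuffled = rearrange(channels, '(g c) -> (c g)', g=4)\n",
    "\n",
    "# the first channel of every group comes first, then the second of every group\n",
    "assert shuffled.tolist() == [0, 2, 4, 6, 1, 3, 5, 7]\n",
    "```"
   ]
  },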
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split a dimension\n",
    "\n",
    "Here's a super-convenient trick.\n",
    "\n",
    "Example: when a network predicts several bboxes for each position\n",
    "\n",
    "Assume we got 8 bboxes, 4 coordinates each. <br />\n",
    "To get coordinated into 4 separate variables, you move corresponding dimension to front and unpack tuple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 8, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       ".einops-answer {\n",
       "    color: transparent;\n",
       "    padding: 5px 15px;\n",
       "    background-color: #def;\n",
       "}\n",
       ".einops-answer:hover { color: blue; } \n",
       "</style>\n",
       "<h4>Answer is: <span class='einops-answer'>(10, 100, 200)</span> (hover to see)</h4>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "bbox_x, bbox_y, bbox_w, bbox_h = \\\n",
    "    rearrange(x, 'b (coord bbox) h w -> coord b bbox h w', coord=4, bbox=8)\n",
    "# now you can operate on individual variables\n",
    "max_bbox_area = reduce(bbox_w * bbox_h, 'b bbox h w -> b h w', 'max')\n",
    "guess(bbox_x.shape)\n",
    "guess(max_bbox_area.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Getting into the weeds of tensor packing\n",
    "\n",
    "you can skip this part - it explains why taking a habit of defining splits and packs explicitly\n",
    "\n",
    "when implementing custom gated activation (like GLU), split is needed:\n",
    "```python\n",
    "y1, y2 = rearrange(x, 'b (split c) h w -> split b c h w', split=2)\n",
    "result = y2 * sigmoid(y2) # or tanh\n",
    "```\n",
    "... but we could split differently\n",
    "```python\n",
    "y1, y2 = rearrange(x, 'b (c split) h w -> split b c h w', split=2)\n",
    "```\n",
    "\n",
    "- first one splits channels into consequent groups: `y1 = x[:, :x.shape[1] // 2, :, :]`\n",
    "- while second takes channels with a step: `y1 = x[:, 0::2, :, :]`\n",
    "\n",
    "This may drive to very *surprising* results when input is\n",
    "- a result of group convolution\n",
    "- a result of bidirectional LSTM/RNN\n",
    "- multi-head attention\n",
    "\n",
    "Let's focus on the second case (LSTM/RNN), since it is less obvious.\n",
    "\n",
    "For instance, cudnn concatenates LSTM outputs for forward-in-time and backward-in-time\n",
    "\n",
    "Also in pytorch GLU splits channels into consequent groups (first way)\n",
    "So when LSTM's output comes to GLU,\n",
    "- forward-in-time produces linear part, and backward-in-time produces activation ...\n",
    "- and role of directions is different, and gradients coming to two parts are different\n",
    "  - that's not what you expect from simple `GLU(BLSTM(x))`, right?\n",
    "\n",
    "`einops` notation makes such inconsistencies explicit and easy-detectable"
   ]
  },
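  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two splits described above is easy to see on labelled channels. A minimal sketch, assuming a made-up array `x_demo` with one batch element and 6 channels (hypothetical, without spatial axes for brevity):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical input: batch of 1, channels labelled 0..5\n",
    "x_demo = np.arange(6).reshape(1, 6)\n",
    "\n",
    "# first way: `split` is the outer factor, so halves are consecutive\n",
    "a1, a2 = rearrange(x_demo, 'b (split c) -> split b c', split=2)\n",
    "assert a1.tolist() == [[0, 1, 2]] and a2.tolist() == [[3, 4, 5]]\n",
    "\n",
    "# second way: `split` is the inner factor, so channels are taken with a step\n",
    "b1, b2 = rearrange(x_demo, 'b (c split) -> split b c', split=2)\n",
    "assert b1.tolist() == [[0, 2, 4]] and b2.tolist() == [[1, 3, 5]]\n",
    "\n",
    "# both agree with the slicing expressions given in the text\n",
    "assert np.array_equal(a1, x_demo[:, :x_demo.shape[1] // 2])\n",
    "assert np.array_equal(b1, x_demo[:, 0::2])\n",
    "```"
   ]
  },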
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shape parsing\n",
    "just a handy utility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from einops import parse_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convolve_2d(x):\n",
    "    # imagine we have a simple 2d convolution with padding,\n",
    "    # so output has same shape as input.\n",
    "    # Sorry for laziness, use imagination!\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imagine we are working with 3d data\n",
    "x_5d = rearrange(x, 'b c x (y z) -> b c x y z', z=20)\n",
    "# but we have only 2d convolutions. \n",
    "# That's not a problem, since we can apply\n",
    "y = rearrange(x_5d, 'b c x y z -> (b z) c x y')\n",
    "y = convolve_2d(y)\n",
    "# not just specifies additional information, but verifies that all dimensions match\n",
    "y = rearrange(y, '(b z) c x y -> b c x y z', **parse_shape(x_5d, 'b c x y z'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'b': 10, 'c': 32, 'x': 100, 'y': 10, 'z': 20}"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parse_shape(x_5d, 'b c x y z')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'batch': 10, 'c': 32}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we can skip some dimensions by writing underscore\n",
    "parse_shape(x_5d, 'batch c _ _ _')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Striding anything\n",
    "\n",
    "Finally, how to convert any operation into a strided operation? <br />\n",
    "(like convolution with strides, aka dilated/atrous convolution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# each image is split into subgrids, each is now a separate \"image\"\n",
    "y = rearrange(x, 'b c (h hs) (w ws) -> (hs ws b) c h w', hs=2, ws=2)\n",
    "y = convolve_2d(y)\n",
    "# pack subgrids back to an image\n",
    "y = rearrange(y, '(hs ws b) c h w -> b c (h hs) (w ws)', hs=2, ws=2)\n",
    "\n",
    "assert y.shape == x.shape"
   ]
  },
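  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check which pixels land in which subgrid, the same two patterns can be applied to a tiny array and compared with strided slicing. A minimal sketch, assuming a made-up array `demo` (hypothetical, much smaller than the tutorial's `x`):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from einops import rearrange\n",
    "\n",
    "# hypothetical small input: batch=2, channels=1, height=4, width=4\n",
    "demo = np.arange(2 * 1 * 4 * 4).reshape(2, 1, 4, 4)\n",
    "\n",
    "sub = rearrange(demo, 'b c (h hs) (w ws) -> (hs ws b) c h w', hs=2, ws=2)\n",
    "assert sub.shape == (8, 1, 2, 2)\n",
    "\n",
    "# the first `b` entries hold the subgrid with row offset 0 and column offset 0\n",
    "assert np.array_equal(sub[:2], demo[:, :, 0::2, 0::2])\n",
    "# the next `b` entries hold row offset 0 and column offset 1\n",
    "assert np.array_equal(sub[2:4], demo[:, :, 0::2, 1::2])\n",
    "\n",
    "# packing the subgrids back restores the original array\n",
    "restored = rearrange(sub, '(hs ws b) c h w -> b c (h hs) (w ws)', hs=2, ws=2)\n",
    "assert np.array_equal(restored, demo)\n",
    "```\n",
    "\n",
    "Each subgrid contains every second pixel, so an operation applied to a subgrid sees neighbours that are two pixels apart in the original image."
   ]
  },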
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layers\n",
    "\n",
    "For frameworks that prefer operating with layers, layers are available.\n",
    "\n",
    "You'll need to import a proper one depending on your backend:\n",
    "\n",
    "```python\n",
    "from einops.layers.chainer import Rearrange, Reduce\n",
    "from einops.layers.gluon import Rearrange, Reduce\n",
    "from einops.layers.keras import Rearrange, Reduce\n",
    "from einops.layers.torch import Rearrange, Reduce\n",
    "```\n",
    "\n",
    "`Einops` layers are identical to operations, and have same parameters. <br />\n",
    "(for the exception of first argument, which should be passed during call)\n",
    "\n",
    "```python\n",
    "layer = Rearrange(pattern, **axes_lengths)\n",
    "layer = Reduce(pattern, reduction, **axes_lengths)\n",
    "\n",
    "# apply layer to tensor\n",
    "x = layer(x)\n",
    "```\n",
    "\n",
    "Usually it is more convenient to use layers, not operations, to build models\n",
    "```python\n",
    "# example given for pytorch, but code in other frameworks is almost identical\n",
    "from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU\n",
    "from einops.layers.torch import Reduce\n",
    "\n",
    "model = Sequential(\n",
    "    Conv2d(3, 6, kernel_size=5),\n",
    "    MaxPool2d(kernel_size=2),\n",
    "    Conv2d(6, 16, kernel_size=5),\n",
    "    # combined pooling and flattening in a single step\n",
    "    Reduce('b c (h 2) (w 2) -> b (c h w)', 'max'), \n",
    "    Linear(16*5*5, 120), \n",
    "    ReLU(),\n",
    "    Linear(120, 10), \n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What's now?\n",
    "\n",
    "- rush through [writing better code with einops+pytorch](https://arogozhnikov.github.io/einops/pytorch-examples.html)\n",
    "\n",
    "Use different framework? Not a big issue, most recommendations transfer well to other frameworks. <br />\n",
    "`einops` works the same way in any framework.\n",
    "\n",
    "Finally - just write your code with einops!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
